# -*- encoding: utf-8 -*-
"""
@File    : tf_lambda_use.py
@Author  : lilong
@Time    : 2022/9/19 6:33 下午
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Lambda

# Usage of the Lambda layer: wrap an arbitrary TensorFlow expression as a layer.
x = np.array(
    [[1, 2],
     [3, 4],
     [5, 6]])

# Element-wise square of the input.
layer_1 = Lambda(lambda t: t ** 2)
x = layer_1(x)
print(x)

# Kept at module level for backward compatibility with any later code that
# reads it; layer_2 below no longer depends on it.
x_shape = tf.shape(x)

# Reshape the input to (-1, dim1, dim0), i.e. swap the sizes of the first two
# dimensions and let -1 infer a leading dimension. This is a reshape, not a
# transpose: element order is unchanged. For the 3x2 array above the result
# has shape (1, 2, 3).
# The shape is taken from the layer's own input `t` rather than from a tensor
# captured from the enclosing scope, so the layer is correct for any input
# with at least two dimensions, not just the one `x` it was built next to.
layer_2 = Lambda(lambda t: tf.reshape(t, (-1, tf.shape(t)[1], tf.shape(t)[0])))
z = layer_2(x)
print(z)
print('---'*20)
